#!/usr/bin/env python3
import argparse
import boto3
import contextlib
import copy
import getpass
import io
import json
import os
import shutil
import subprocess
import sys
import tempfile
import threading
import time
import yaml

def main(args):
    """Dispatch the command line to the matching sub-command handler.

    ``args`` is the argument list without the program name; its first element
    selects the command and the remainder is handed to that command's handler.
    """
    name, remaining = parse_command(args)
    handler = get_command_function(name)
    handler(remaining)

def log(level, msg, *args):
    """Write one ``level: message`` line to standard output.

    ``msg`` is %-formatted with ``args`` only when ``args`` is non-empty, so a
    message logged without arguments may contain literal ``%`` characters.

    The formatting is applied to ``msg`` alone, before the level prefix is
    added: ``level`` is caller-supplied text (template names are used as
    levels) and must never be interpreted as part of the format string.
    """
    if args:
        msg %= args
    sys.stdout.write(level + ': ' + msg + '\n')

def warn(msg, *args):
    """Log ``msg`` (formatted with ``args``) at the ``warn`` level and continue."""
    level = 'warn'
    log(level, msg, *args)

def error(msg, *args):
    """Log ``msg`` at the ``error`` level, then terminate with exit status 1."""
    level = 'error'
    log(level, msg, *args)
    sys.exit(1)

def usage(msg, *args):
    """Log a ``usage`` message for a malformed command line, then exit with status 1."""
    level = 'usage'
    log(level, msg, *args)
    sys.exit(1)

def get_command_function(command):
    """Return the handler registered in ``COMMANDS`` for ``command``.

    An unknown command is reported through ``usage()``, which prints the
    message and terminates the process, so this function only ever returns a
    valid handler.
    """
    function = COMMANDS.get(command)
    if function is None:
        # usage() exits the process; no further handling is needed here.
        usage('invalid command: packer %s ...', command)
    return function

def parse_command(args):
    """Split the argument list into ``(command, remaining_arguments)``.

    When no command is given the problem is reported through ``usage()``,
    which prints the message and terminates the process.
    """
    if not args:
        # usage() exits the process; no further handling is needed here.
        usage('missing command: pack [command] ...')
    return args[0], args[1:]

def parse_build_arguments(args):
    """Parse the options of the ``build`` command.

    Returns the ``argparse`` namespace with ``template`` (required), ``path``
    (defaults to ``'.'``), ``recurse`` (flag) and ``setting`` (list of the
    repeated ``-s`` values, or ``None`` when the option was not given).
    """
    option_table = (
        (('-t', '--template'), dict(metavar='NAME', required=True, help="The name of the template to build")),
        (('-p', '--path'), dict(metavar='PATH', default='.', help="The path to the folder containing the templates")),
        (('-r', '--recurse'), dict(action='store_true', help="Recursively build images depending on the template")),
        (('-s', '--setting'), dict(action='append', help="Add variables to the build")),
    )

    build_parser = argparse.ArgumentParser()
    for flags, spec in option_table:
        build_parser.add_argument(*flags, **spec)
    return build_parser.parse_args(args)

def parse_validate_arguments(args):
    """Parse the options of the ``validate`` command.

    Returns the ``argparse`` namespace with ``template`` (required), ``path``
    (defaults to ``'.'``), ``recurse`` (flag) and ``setting`` (list of the
    repeated ``-s`` values, or ``None`` when the option was not given).
    """
    option_table = (
        (('-t', '--template'), dict(metavar='NAME', required=True, help="The name of the template to validate")),
        (('-p', '--path'), dict(metavar='PATH', default='.', help="The path to the folder containing the templates")),
        (('-r', '--recurse'), dict(action='store_true', help="Recursively validate images depending on the template")),
        (('-s', '--setting'), dict(action='append', help="Add variables to the build")),
    )

    validate_parser = argparse.ArgumentParser()
    for flags, spec in option_table:
        validate_parser.add_argument(*flags, **spec)
    return validate_parser.parse_args(args)

def parse_plan_arguments(args):
    """Parse the options of the ``plan`` command.

    Returns the ``argparse`` namespace with ``template`` (required), ``path``
    (defaults to ``'.'``), ``recurse`` (flag) and ``setting`` (list of the
    repeated ``-s`` values, or ``None`` when the option was not given).
    """
    option_table = (
        (('-t', '--template'), dict(metavar='NAME', required=True, help="The name of the template to plan")),
        (('-p', '--path'), dict(metavar='PATH', default='.', help="The path to the folder containing the pack templates")),
        (('-r', '--recurse'), dict(action='store_true', help="Recursively plan images depending on the template")),
        (('-s', '--setting'), dict(action='append', help="Add variables to the plan")),
    )

    plan_parser = argparse.ArgumentParser()
    for flags, spec in option_table:
        plan_parser.add_argument(*flags, **spec)
    return plan_parser.parse_args(args)

def command_build(args):
    """Handle ``build``: run ``packer build`` for the selected template.

    The templates found under ``--path`` are loaded with the ``--setting``
    overrides applied. For the template named by ``--template``, a missing
    ``source_ami`` is resolved from the most recent parent AMI before the
    build is started (descendants are built too when ``--recurse`` is set).

    Exits with an error when no loaded template has the requested name,
    instead of silently doing nothing.
    """
    options = parse_build_arguments(args)

    for t in load_templates(options.path, settings=parse_settings(options.setting)):
        if t.name == options.template:
            if t.ami.get('source_ami') is None:
                fetch_source_ami(t)
            t.do('build', templates_dir=os.path.join(os.getcwd(), options.path), recurse=options.recurse)
            # Template names come from directory names, so there is one match at most.
            return

    error('no template named %s was found in %s', options.template, options.path)

def command_validate(args):
    """Handle ``validate``: run ``packer validate`` for the selected template.

    The templates found under ``--path`` are loaded with the ``--setting``
    overrides applied. For the template named by ``--template``, a missing
    ``source_ami`` is resolved from the most recent parent AMI before the
    validation is started (descendants are validated too when ``--recurse``
    is set).

    Exits with an error when no loaded template has the requested name,
    instead of silently doing nothing.
    """
    options = parse_validate_arguments(args)

    for t in load_templates(options.path, settings=parse_settings(options.setting)):
        if t.name == options.template:
            if t.ami.get('source_ami') is None:
                fetch_source_ami(t)
            t.do('validate', templates_dir=os.path.join(os.getcwd(), options.path), recurse=options.recurse)
            # Template names come from directory names, so there is one match at most.
            return

    error('no template named %s was found in %s', options.template, options.path)

def command_plan(args):
    """Handle ``plan``: print the packer plan of the selected template as JSON.

    The templates found under ``--path`` are loaded with the ``--setting``
    overrides applied. For the template named by ``--template``, a missing
    ``source_ami`` is resolved from the most recent parent AMI, then the plan
    (a list of plans when ``--recurse`` is set) is printed with sorted keys
    and a two-space indent.

    Exits with an error when no loaded template has the requested name,
    instead of silently printing nothing.
    """
    options = parse_plan_arguments(args)

    for t in load_templates(options.path, settings=parse_settings(options.setting)):
        if t.name == options.template:
            if t.ami.get('source_ami') is None:
                fetch_source_ami(t)
            print(json.dumps(t.plan(recurse=options.recurse), sort_keys=True, indent=2, separators=(',', ': ')))
            return

    error('no template named %s was found in %s', options.template, options.path)

def parse_settings(settings):
    """Turn the repeated ``KEY=VALUE`` command-line settings into a dict.

    ``settings`` is the list collected by argparse, or ``None`` when the
    option was never given (an empty dict is returned in that case). When a
    key is repeated, the last occurrence wins.
    """
    parsed = { }
    if settings is not None:
        for raw in settings:
            key, value = split_setting(raw)
            parsed[key] = value
    return parsed

def split_setting(s):
    """Split one ``KEY=VALUE`` setting into ``(key, value)``.

    Only the first ``=`` separates the key from the value. The value is parsed
    as YAML so that ``true``, ``42`` or ``[a, b]`` become native Python
    values; text that is not valid YAML is kept as the raw string.

    Raises:
        ValueError: if ``s`` contains no ``=``.
    """
    key, separator, raw = s.partition('=')
    if not separator:
        raise ValueError('invalid setting %r: expected KEY=VALUE' % s)

    try:
        # safe_load: the value comes straight from the command line and only
        # plain YAML scalars/collections are expected. yaml.load() without an
        # explicit Loader is also rejected by recent PyYAML releases.
        value = yaml.safe_load(raw)
    except (ValueError, yaml.YAMLError):
        # Malformed YAML is not an error here: fall back to the literal text.
        value = raw
    return key, value

def fetch_source_ami(template):
    """Set ``template.ami['source_ami']`` to the latest AMI built from its base.

    When no such AMI can be found the failure is reported through ``error()``
    (which terminates the process), naming the parent templates that would
    have to be built first.
    """
    latest = get_latest_parent_ami(template)
    if latest is None:
        parent_names = ', '.join(p.name for p in get_parent_templates(template))
        error('no source AMI available for %s, please build one of the parent AMI first (%s)', template.name, parent_names)
    template.ami['source_ami'] = latest

def get_parent_templates(template):
    """Return the ancestors of ``template``, nearest parent first.

    The chain is followed through the ``parent`` attribute until a template
    without a parent is reached; ``template`` itself is not included.
    """
    ancestors = [ ]
    node = template.parent
    while node is not None:
        ancestors.append(node)
        node = node.parent
    return ancestors

def get_account_id():
    """Return the ID of the AWS account the current credentials belong to.

    The ID is read from STS ``GetCallerIdentity``, which reports the account
    of the calling identity directly. The previous approach parsed the ARN of
    the first IAM user or role it could read, which needed IAM list
    permissions and failed with an unbound variable when the account had no
    users and no role could be fetched.
    """
    return boto3.client('sts').get_caller_identity()['Account']

def get_latest_parent_ami(template):
    """Return the ID of the newest AMI built from ``template.base``.

    Images owned by the current account whose name starts with ``<base>/``
    are looked up, and the one with the greatest name wins. Returns ``None``
    when the template has no base or when no image matches.
    """
    if not template.base:
        return None

    template.log('looking up account id...')
    account_id = get_account_id()

    ec2 = boto3.resource('ec2')
    prefix = template.base + '/'
    query = [
        { 'Name': 'name', 'Values': [prefix + '*'] },
        { 'Name': 'owner-id', 'Values': [account_id] },
    ]

    template.log('looking up parent AMI...')
    # The name is re-checked locally in addition to the server-side filter.
    candidates = [
        image
        for image in ec2.images.filter(Filters=query).all()
        if image.name is not None and image.name.startswith(prefix)
    ]

    if not candidates:
        return None
    return sorted(candidates, key=lambda image: image.name)[-1].id

def list_templates(path):
    """Yield the name of every non-hidden sub-directory found directly in ``path``.

    Entries whose name starts with ``.`` and entries that are not directories
    are ignored.
    """
    for entry in os.listdir(path):
        if entry.startswith('.'):
            continue
        if not os.path.isdir(os.path.join(path, entry)):
            continue
        yield entry

def load_templates(path, settings=None):
    """Load every template directory found in ``path`` and link them into trees.

    Each non-hidden sub-directory is expected to hold a ``packer.yml`` file
    defining exactly one of ``ami`` (a root template) or ``base`` (the name of
    the template it derives from). Directories whose file is missing, invalid
    or ambiguous are skipped with a warning.

    ``settings``, when not ``None``, is a dict whose entries are written into
    the ``ami`` section of every loaded template, overriding values read from
    the file.

    Returns the flat list of loaded templates, with ``parent`` / ``children``
    set for templates whose base was found, and ``children`` sorted by name.
    """
    templates = [ ]

    for file_name in list_templates(path):
        try:
            with open(os.path.join(path, file_name, 'packer.yml')) as f:
                # safe_load: template files only need plain YAML, and calling
                # yaml.load() without a Loader is rejected by recent PyYAML
                # releases (which made every template get skipped).
                pack = yaml.safe_load(f)

            if 'ami' not in pack and 'base' not in pack:
                warn("skipping %s - neither 'ami' nor 'base' properties exist", file_name)
                continue

            if 'ami' in pack and 'base' in pack:
                warn("skipping %s - both 'ami' and 'base' properties exist", file_name)
                continue

            ami = pack.get('ami')
            if settings is not None:
                if ami is None:
                    ami = { }
                ami.update(settings)

            templates.append(Template(
                name      = file_name,
                ami       = ami,
                base      = pack.get('base'),
                scripts   = pack.get('scripts'),
                variables = pack.get('variables'),
                execute   = pack.get('execute'),
            ))
        except Exception as e:
            # Deliberately best-effort: one broken directory must not prevent
            # the other templates from loading.
            warn('skipping %s - none or invalid packer.yml file: %s', file_name, e)

    for t in templates:
        if t.base is not None:
            # A template whose base is not found only gets a warning; it stays
            # in the returned list with no parent.
            if not insert_template(templates, t):
                warn('skipping %s - no based template named %s was found', t.name, t.base)

    for t in templates:
        t.children.sort(key=lambda child: child.name)

    return templates

def insert_template(templates, t0):
    """Attach ``t0`` to the first template in ``templates`` named like its base.

    On success ``t0.parent`` is set, ``t0`` is appended to the parent's
    ``children`` and ``True`` is returned; ``False`` is returned when no
    template carries the name ``t0.base``.
    """
    match = next((candidate for candidate in templates if t0.base == candidate.name), None)
    if match is None:
        return False

    t0.parent = match
    match.children.append(t0)
    return True

class Template(object):
    """A packer template described by a ``packer.yml`` file.

    Templates form a tree through ``parent`` / ``children``. This class only
    initialises those links (no parent, no children); they are filled in from
    outside. When a packer plan is generated, a template inherits the
    variables, builder settings and execute command of its ancestors.

    Attributes:
        name: the template name; also used as the prefix of the AMI name.
        ami: builder settings that override the inherited ones.
        base: name of the template this one derives from, or ``None``.
        parent: the parent ``Template``, or ``None``.
        children: the templates deriving from this one.
        begin, end: the ``Script`` objects run before and after the
            template's own scripts.
    """

    def __init__(self, name=None, ami=None, base=None, scripts=None, variables=None, execute=None):
        """Create an unlinked template; ``None`` collections become empty ones."""
        if ami is None:
            ami = { }

        if scripts is None:
            scripts = [ ]

        if variables is None:
            variables = { }

        self._variables = variables
        self._execute   = execute
        self._scripts   = scripts
        self.name       = name
        self.ami        = ami
        self.base       = base
        self.parent     = None
        self.children   = [ ]

        self.begin = Script(
            name    = 'pack-begin.sh',
            content = BEGIN_SCRIPT,
        )

        self.end = Script(
            name    = 'pack-end.sh',
            content = END_SCRIPT,
        )

    def __repr__(self):
        """Return the template tree rooted here, one indented name per line."""
        return '\n'.join(self._repr([ ]))

    def __str__(self):
        return self.name

    def _repr(self, parts, depth=0):
        """Append this template's name and its descendants' to ``parts``.

        Each level of depth is indented by two spaces. Returns ``parts``.
        """
        parts.append(2 * depth * ' ' + self.name)
        depth += 1
        for c in self.children:
            c._repr(parts, depth)
        return parts

    def variables(self, source_ami=None):
        """Return the packer variables: the ancestors' overridden by this template's.

        ``source_ami`` is accepted but not used. Values are deep-copied so the
        returned dict can be modified freely.
        """
        if self.parent is not None:
            variables = self.parent.variables()
        else:
            variables = { }
        # .items() is required: iterating the dict itself yields keys only.
        for k, v in self._variables.items():
            variables[k] = copy.deepcopy(v)
        return variables

    def builders(self, source_ami=None):
        """Return the packer ``builders`` list (a single amazon-ebs builder).

        Settings are layered from the root template down to this one. The
        region falls back to ``AWS_REGION`` / ``AWS_DEFAULT_REGION`` when no
        template sets it, and an explicit ``source_ami`` argument overrides
        whatever the templates define (``'<computed>'`` is the placeholder
        when nothing does).
        """
        if self.parent is not None:
            builders = self.parent.builders()
        else:
            builders = [{ }]
        b = builders[0]
        b['source_ami'] = '<computed>'

        for k, v in self.ami.items():
            b[k] = copy.deepcopy(v)

        if 'region' not in b:
            for k in ('AWS_REGION', 'AWS_DEFAULT_REGION'):
                if k in os.environ:
                    b['region'] = os.environ[k]
                    break

        if source_ami is not None:
            b['source_ami'] = source_ami

        # These are set last so they always describe this template, whatever
        # was inherited from the parents.
        b['type'] = 'amazon-ebs'
        b['name'] = 'amazon-ebs'
        b['ami_name'] = '%s/%s' % (self.name, TIMESTAMP)
        return builders

    def provisioners(self):
        """Return the packer ``provisioners``: upload ``root``, then run the scripts."""
        p0 = {
            'type'        : 'file',
            'source'      : 'root',
            'destination' : '/tmp',
        }

        p1 = { 'type': 'shell', 'scripts': self.scripts() }
        ex = self.execute_command()

        if ex is not None:
            p1['execute_command'] = ex

        return [p0, p1]

    def scripts(self):
        """Return the script paths to run: begin script, own scripts, end script."""
        scripts = [self.begin.name]
        scripts.extend('scripts/' + s for s in self._scripts)
        scripts.append(self.end.name)
        return scripts

    def execute_command(self):
        """Return the ``execute`` setting of the root template of this tree.

        Only the root's value is used; a value set on a derived template is
        ignored.
        """
        if self.parent is None:
            return self._execute
        return self.parent.execute_command()

    def plan_self(self, source_ami=None):
        """Return the packer plan (variables, provisioners, builders) of this template."""
        return {
            'variables'    : self.variables(),
            'provisioners' : self.provisioners(),
            'builders'     : self.builders(source_ami=source_ami),
        }

    def plan(self, source_ami=None, recurse=False):
        """Return this template's plan, or a list of plans when ``recurse`` is set.

        With ``recurse`` the list holds the plan of this template followed by
        those of all its descendants in breadth-first order.
        """
        if not recurse:
            return self.plan_self(source_ami=source_ami)
        tasks = [self]
        plans = [ ]

        while tasks:
            t = tasks.pop(0)
            tasks.extend(t.children)
            plans.append(t.plan_self(source_ami=source_ami))

        return plans

    def do(self, command, templates_dir=None, source_ami=None, recurse=False):
        """Run ``packer <command>`` for this template and optionally its children.

        With ``recurse`` the children are run afterwards, using the AMI this
        run produced as their source AMI, or ``'<computed>'`` when the command
        produced none.
        """
        ami = self.do_self(command, templates_dir=templates_dir, source_ami=source_ami)
        if not ami:
            ami = '<computed>'
        if recurse:
            self.do_children(command, templates_dir=templates_dir, source_ami=ami, recurse=recurse)

    def do_self(self, command, templates_dir=None, source_ami=None):
        """Run ``packer <command>`` for this template in a fresh temporary directory.

        The template's ``root`` and ``scripts`` directories, the begin/end
        scripts and the generated ``plan.json`` are placed in the working
        directory; packer's output goes to ``stdout.log`` / ``stderr.log``
        there. The directory is kept afterwards so the logs can be inspected.

        ``templates_dir`` is the directory holding the template directories
        (the current directory when ``None``).

        Returns the AMI ID found in packer's output, or ``None`` when the
        output mentions no AMI.

        Raises:
            Exception: if packer exits with a non-zero status, or if its
                output cannot be read back.
        """
        if templates_dir is None:
            templates_dir = '.'

        pwd = tempfile.mkdtemp()
        stdout_path = os.path.join(pwd, 'stdout.log')

        with open(stdout_path, 'wb') as stdout, \
             open(os.path.join(pwd, 'stderr.log'), 'wb') as stderr:

            for d in ('root', 'scripts'):
                shutil.copytree(os.path.join(templates_dir, self.name, d), os.path.join(pwd, d))

            for s in (self.begin, self.end):
                with open(os.path.join(pwd, s.name), 'w') as f:
                    f.write(s.content)

            with open(os.path.join(pwd, 'plan.json'), 'w') as f:
                f.write(json.dumps(self.plan_self(source_ami=source_ami), sort_keys=True, indent=2, separators=(',', ': ')))

            cmd = ['packer', command, 'plan.json']
            self.log('working directory is %s', pwd)
            self.log(' '.join(cmd))
            returncode = subprocess.call(cmd, stdin=subprocess.DEVNULL, stdout=stdout, stderr=stderr, cwd=pwd)

        if returncode != 0:
            raise Exception("Building %s failed: see logs in %s" % (self.name, pwd))

        try:
            ami = self._last_ami_id(stdout_path)
        except Exception as e:
            raise Exception("Reading AMI ID of %s failed: see logs in %s" % (self.name, pwd)) from e

        if ami:
            self.log(ami)
            return ami
        return None

    @staticmethod
    def _last_ami_id(log_path):
        """Return the AMI ID from the last line of ``log_path`` mentioning ``ami-``.

        Equivalent to ``grep ami- FILE | tail -n 1 | cut -d' ' -f2`` without
        going through a shell: the second space-separated field of that line
        is returned, or the whole line when it contains no space. Returns
        ``None`` when no line matches or the field is empty.
        """
        last = None
        # newline='\n' keeps lines split on "\n" only, as grep does.
        with open(log_path, encoding='utf-8', newline='\n') as f:
            for line in f:
                if 'ami-' in line:
                    last = line
        if last is None:
            return None

        fields = last.rstrip('\n').split(' ')
        field = fields[1] if len(fields) > 1 else fields[0]
        return field.strip() or None

    def do_children(self, *args, **kwargs):
        """Call ``do(*args, **kwargs)`` on every child concurrently and wait for all.

        Each child runs in its own daemon thread.
        NOTE(review): an exception raised inside a child thread is not
        re-raised here, so a failed child build does not fail the caller.
        """
        threads = [ ]

        for c in self.children:
            t = threading.Thread(target=c.do, args=args, kwargs=kwargs, daemon=True)
            t.start()
            threads.append(t)

        for t in threads:
            t.join()

    def log(self, msg, *args):
        """Log ``msg`` with ``==> <template name>`` as the level prefix."""
        log('==> ' + self.name, msg, *args)

class Script:
    """A named shell script: the file name to write and the text to put in it."""

    def __init__(self, name=None, content=None):
        """Store the script's file ``name`` and its ``content``."""
        self.content = content
        self.name    = name

    def __str__(self):
        """Return the script's file name."""
        return self.name

    def __repr__(self):
        """Return a short debug representation showing only the name."""
        label = 'Script { name = %s }' % self.name
        return label

# UTC time at which the module was loaded, formatted so that lexicographic
# order matches chronological order. It is used as the suffix of the AMI names
# generated by Template.builders().
TIMESTAMP = time.strftime(
    '%Y.%m.%d-%H.%M.%S',
    time.gmtime(),
)

# Maps each sub-command name accepted on the command line to its handler.
COMMANDS = dict(
    build=command_build,
    plan=command_plan,
    validate=command_validate,
)

# Content of the 'pack-begin.sh' script, which Template.scripts() lists before
# the template's own scripts. If /tmp/root is not empty its content is copied
# onto /, then systemd is asked to reload its unit files.
# NOTE(review): with bash's default globbing `*` does not match dot-files, so
# hidden entries at the top level of /tmp/root are presumably not copied (and
# `cp` fails under `set -e` if only hidden entries exist) - confirm intent.
BEGIN_SCRIPT = """#!/bin/bash
set -e

if [ "$(ls -A /tmp/root)" ]
then
  cd /tmp/root
  cp -r * /
fi

systemctl daemon-reload
"""

# Content of the 'pack-end.sh' script, which Template.scripts() lists after
# the template's own scripts. It copies /tmp/root onto / again, reloads
# systemd, cleans up apt, empties /tmp and /var/tmp, removes the SSH
# authorized_keys file and truncates every file under /var/log.
# NOTE(review): `$f` and `$HOME` are unquoted, so a path containing whitespace
# would be split into several words - confirm this cannot happen on the images.
END_SCRIPT = """#!/bin/bash
set -e

if [ "$(ls -A /tmp/root)" ]
then
  cd /tmp/root
  cp -r * /
fi

systemctl daemon-reload

apt-get -y autoremove
apt-get -y autoclean

rm -rf /tmp/*
rm -rf /var/tmp/*
rm -rf $HOME/.ssh/authorized_keys

for f in $(find /var/log -type f) ; do
  dd if=/dev/null of=$f
done
"""

# Script entry point: when the file is executed directly (not imported), pass
# the command-line arguments, without the program name, to main().
if __name__ == '__main__':
    main(sys.argv[1:])
